"""Control profile to control H1 in floating mode."""
from typing import Optional

import numpy as np
from gymnasium.core import ActType
from pyquaternion import Quaternion
from xr import Posef

from bigym.bigym_env import BiGymEnv
from vr.ik.h1_upper_body_ik import H1UpperBodyIK, Pose
from vr.viewer import Side
from vr.viewer.control_profiles.control_profile import ControlProfile
from vr.viewer.pyopenxr_to_mujoco_converter import (
    vector_from_pyopenxr,
    quaternion_from_pyopenxr,
)
from vr.viewer.xr_context import XRContextObject


class StandardControlProfile(ControlProfile):
    """Control profile for H1 in floating mode.

    Notes:
        - Use controller triggers to control the grippers.
        - Use the right A button to enable/disable synchronization of position.
        - Use the right B button to enable/disable synchronization of rotation.
        - Position of controllers is used as the target for the corresponding arm of H1.
    """

    # Scale applied to the (clamped) pelvis-to-HMD position delta on every call.
    POSITION_SMOOTHING = 0.01
    # Scale applied to the pelvis-to-HMD rotation delta on every call.
    ROTATION_SMOOTHING = 0.01
    # Subtracted from the z component of the pelvis-to-HMD position delta.
    VERTICAL_OFFSET = 0.7
    # Offset rotated by the HMD orientation and added to the HMD position.
    HMD_PIVOT_OFFSET = np.array([0, -0.2, 0])

    def __init__(self):
        """Initialize sync flags; the IK solver is created in `bind_environment`."""
        super().__init__()

        # HMD pose recorded at the end of the latest `get_next_action` call.
        self._prev_hmd_pos: Optional[np.ndarray] = None
        self._prev_hmd_rot: Optional[Quaternion] = None
        # When False, the corresponding floating-base delta is scaled to zero.
        self._sync_position = True
        self._sync_rotation = True

        # Will be initialized later
        self._ik: Optional[H1UpperBodyIK] = None

    def bind_environment(self, env: BiGymEnv):
        """Bind environment and initialize IK solver."""
        super().bind_environment(env)
        self._ik = H1UpperBodyIK(env)

    @staticmethod
    def _controller_target(aim_pose, offset: np.ndarray) -> Pose:
        """Convert a controller aim pose into an IK target shifted by `offset`."""
        position = vector_from_pyopenxr(aim_pose.position) + offset
        orientation = Quaternion(quaternion_from_pyopenxr(aim_pose.orientation))
        return Pose(position, orientation)

    def get_next_action(
        self, context: XRContextObject, steps_predicted: int, space_offset: Posef
    ) -> ActType:
        """Build the next action from the current headset and controller state.

        The floating base is driven towards the HMD pose, the arms are solved
        with IK to reach the controller aim poses, and the last two action
        entries are set from the rounded left/right trigger values.

        `bind_environment` must have been called before this method, since the
        IK solver is only created there.

        Args:
            context: XR context providing controller state and the HMD pose.
            steps_predicted: Unused by this profile.
            space_offset: Pose whose position is added to every tracked position.

        Returns:
            Array shaped like the lower bound of the environment action space.
        """
        offset = space_offset.position.as_numpy()
        left_input = context.input.state[Side.LEFT]
        right_input = context.input.state[Side.RIGHT]

        gripper_left = left_input.trigger_value
        gripper_right = right_input.trigger_value
        target_left = self._controller_target(left_input.pose_aim, offset)
        target_right = self._controller_target(right_input.pose_aim, offset)

        hmd_pose = context.input.hmd_pose
        hmd_rot = Quaternion(quaternion_from_pyopenxr(hmd_pose.orientation))
        hmd_pos = vector_from_pyopenxr(hmd_pose.position) + offset
        hmd_pos += hmd_rot.rotate(self.HMD_PIVOT_OFFSET)

        # Toggle position and rotation sync
        if right_input.a_click and right_input.a_toggle:
            self._sync_position = not self._sync_position
        if right_input.b_click and right_input.b_toggle:
            self._sync_rotation = not self._sync_rotation

        pelvis = self._env.robot.pelvis
        pelvis_pose = Pose(pelvis.get_position(), Quaternion(pelvis.get_quaternion()))

        # Position delta from pelvis to HMD, clamped to unit length before
        # smoothing so that a distant HMD cannot produce an oversized step.
        delta_position = np.array(hmd_pos - pelvis_pose.position)
        delta_position[2] -= self.VERTICAL_OFFSET
        magnitude = np.linalg.norm(delta_position)
        if magnitude > 1:
            delta_position /= magnitude
        delta_position *= self.POSITION_SMOOTHING * float(self._sync_position)

        relative_rotation = (
            hmd_rot
            * pelvis_pose.orientation.inverse
            * Quaternion(axis=[0, 0, 1], angle=np.pi / 2)
        )
        # `yaw_pitch_roll` is flipped to obtain (roll, pitch, yaw) ordering.
        delta_rotation = np.flip(np.array(relative_rotation.yaw_pitch_roll))
        delta_rotation *= self.ROTATION_SMOOTHING * float(self._sync_rotation)

        control = np.zeros_like(self._env.action_space.low)

        # Control floating base: only axes with an actuator get a control entry.
        floating_base = self._env.robot.floating_base
        floating_base_control = []
        if floating_base:
            floating_base_control.extend(
                delta
                for delta, actuator in zip(
                    delta_position, floating_base.position_actuators
                )
                if actuator
            )
            floating_base_control.extend(
                delta
                for delta, actuator in zip(
                    delta_rotation, floating_base.rotation_actuators
                )
                if actuator
            )
        for index, value in enumerate(floating_base_control):
            control[index] = value

        # Control arms
        # Without a floating base there are no leading base DOFs, so the arm
        # actuators start at index 0 instead of dereferencing `None`.
        start_index = floating_base.dof_amount if floating_base is not None else 0
        end_index = start_index + len(self._env.robot.limb_actuators)

        arms_qpos = np.array(self._env.robot.qpos_actuated[start_index:end_index])
        # First half is passed as the left arm, second half as the right arm.
        qpos_arm_left, qpos_arm_right = np.split(arms_qpos, 2)
        solution = self._ik.solve(
            pelvis_pose=pelvis_pose,
            qpos_arm_left=qpos_arm_left,
            qpos_arm_right=qpos_arm_right,
            target_pose_left=target_left,
            target_pose_right=target_right,
        )
        control[start_index:end_index] = solution

        # Control grippers: triggers are rounded to 0 or 1.
        control[-2] = np.clip(np.round(gripper_left), 0, 1)
        control[-1] = np.clip(np.round(gripper_right), 0, 1)

        self._prev_hmd_pos = hmd_pos
        self._prev_hmd_rot = hmd_rot
        return control
